import torch
import torch.nn as nn

class Task_Classification(nn.Module):
    """Three-layer fully connected classification head.

    The head is ``Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout ->
    Linear``, stored as ``self.classifier`` (an ``nn.Sequential``). The
    positions of the layers inside the sequential are fixed, so the linear
    layers are always registered as ``classifier.0``, ``classifier.3`` and
    ``classifier.6``.

    Args:
        num_features: Size of each sample after flattening every dimension
            except the first.
        num_classes: Number of output units of the last linear layer.
        hidden_dim: Width of the two hidden linear layers. Defaults to 4096.
        dropout: Probability passed to both ``nn.Dropout`` layers. Defaults
            to 0.5, which is also ``nn.Dropout``'s own default.
    """

    def __init__(self, num_features, num_classes=10, *, hidden_dim=4096,
                 dropout=0.5):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            # inplace=True: the activation overwrites the preceding linear
            # layer's output instead of allocating a new tensor.
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Flatten ``x`` from dimension 1 onward and apply the classifier.

        Args:
            x: Input tensor; after flattening, its second dimension must
                equal ``num_features``.

        Returns:
            The output of the final linear layer (no softmax is applied),
            with ``num_classes`` entries per sample.
        """
        x = torch.flatten(x, 1)
        return self.classifier(x)

class Task_Classification_L1(nn.Module):
    """Classification head made of a single linear layer.

    The layer is wrapped in an ``nn.Sequential`` stored as
    ``self.classifier``, so its parameters are registered under
    ``classifier.0``. Constructing an instance prints a short confirmation
    message to standard output.

    Args:
        num_features: Size of each sample after flattening every dimension
            except the first.
        num_classes: Number of output units of the linear layer.
    """

    def __init__(self, num_features, num_classes=10):
        super().__init__()
        projection = nn.Linear(num_features, num_classes)
        # Keep the Sequential wrapper so the parameter names stay
        # "classifier.0.weight" / "classifier.0.bias".
        self.classifier = nn.Sequential(projection)
        print("L1 classifier defined")

    def forward(self, x):
        """Flatten ``x`` from dimension 1 onward and apply the linear layer.

        Args:
            x: Input tensor; after flattening, its second dimension must
                equal ``num_features``.

        Returns:
            The linear layer's output, with ``num_classes`` entries per
            sample.
        """
        flattened = x.flatten(start_dim=1)
        return self.classifier(flattened)